from torch import nn
from abc import ABC
from abc import abstractmethod

class AbstractTrasnformerLayer(ABC):
    """Abstract interface for the transformer-layer implementations in this module.

    Both the constructor and ``forward`` are abstract, so this class cannot be
    instantiated directly; concrete layers must override both.

    NOTE: "Trasnformer" is a misspelling, kept because the class name is part
    of the public interface.
    """

    @abstractmethod
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout,
        norm,
        norm_first: bool,
        causal: bool,
    ):
        """Configure the layer.

        Args:
            embed_dim: model width.
            num_heads: number of attention heads.
            dropout: dropout setting for the layer.
            norm: identifier of the normalization to use.
            norm_first: presumably selects pre-norm (True) vs post-norm
                (False); confirm per implementation.
            causal: whether attention should be causal.
        """

    @abstractmethod
    def forward(self, x, attn_mask, output_attentions):
        """Run the layer on ``x``.

        Args:
            x: input tensor (layout defined by the implementation).
            attn_mask: attention mask; semantics defined by the implementation.
            output_attentions: whether attention weights are requested.
        """

class VanillaTransformerLayer(nn.Module, AbstractTrasnformerLayer):
    """Thin wrapper around ``torch.nn.TransformerEncoderLayer``.

    Implements the ``AbstractTrasnformerLayer`` interface using the stock
    PyTorch encoder layer: GELU activation, a feed-forward width of
    ``4 * embed_dim``, LayerNorm only, and non-causal attention. Returning
    attention weights is not supported.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        norm='layernorm',
        norm_first=False,
        causal=False,
    ):
        """Build the wrapped encoder layer.

        Args:
            embed_dim: model width (``d_model`` of the wrapped layer).
            num_heads: number of attention heads.
            dropout: dropout probability forwarded to the wrapped layer.
            norm: normalization type; only ``'layernorm'`` is accepted.
            norm_first: forwarded as ``norm_first`` to the wrapped layer.
            causal: must be falsy; causal attention is not supported.

        Raises:
            ValueError: if ``norm`` is not ``'layernorm'`` or ``causal`` is truthy.
        """
        super().__init__()
        # Explicit raises instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently build a non-causal LayerNorm
        # layer for an unsupported configuration.
        if norm != 'layernorm':
            raise ValueError('Vanilla transformer only supports layernorm.')
        if causal:
            raise ValueError('Vanilla transformer does not supports causal inference.')
        self.layer = nn.TransformerEncoderLayer(
            embed_dim,
            num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            norm_first=norm_first,
        )
        # Lets callers check whether forward(output_attentions=True) is usable.
        self.support_output_attentions = False

    def forward(self, x, attn_mask=None, output_attentions=False):
        """Apply the encoder layer to a single sequence.

        Args:
            x: input tensor; presumably shape ``(seq_len, embed_dim)`` for one
                unbatched sequence — TODO confirm against callers.
            attn_mask: optional mask forwarded as ``src_mask``.
            output_attentions: must be falsy.

        Returns:
            The wrapped layer's output with the inserted batch axis removed.

        Raises:
            NotImplementedError: if ``output_attentions`` is truthy.
        """
        if output_attentions:
            raise NotImplementedError('output_attentions not implemented for VanillaTransformer')
        # ``batch_first`` is not set on the wrapped layer, so it uses PyTorch's
        # default sequence-first layout; add a size-1 batch axis at dim 1 and
        # drop it again afterwards.
        out = self.layer(x.unsqueeze(1), src_mask=attn_mask)
        return out[:, 0, :]